# Architecture hyper-parameters handed to MultiScaleAutoencoderKL.from_config
# below. They presumably have to match the checkpoint that is loaded into the
# model afterwards — TODO confirm before editing any value.
vae_config = {
    "sample_size": 512,
    "in_channels": 3,
    "out_channels": 3,
    # Six encoder and six decoder blocks: the same count as block_out_channels.
    "down_block_types": ["DownEncoderBlock2D"] * 6,
    "up_block_types": ["UpDecoderBlock2D"] * 6,
    "block_out_channels": [64, 128, 256, 256, 512, 512],
    "layers_per_block": 2,
    "act_fn": "silu",
    "latent_channels": 16,
    "norm_num_groups": 32,
    "compression_range": 4,
    "starting_block_idx": 1,
    "scaling_factor": 0.18215,
    "mid_block_attention_head_dim": 1,
    "mid_block_processing": "conv",
    "num_layers": 8,
    "metric": "jpeg",
    "mid_block_supervision": False,
}
from nested_vae import MultiScaleAutoencoderKL
import torch
import yaml
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
from tqdm import tqdm
import json
from matplotlib.pyplot import figure
from matplotlib import pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
import random
import torch.nn.init as init
import json

def set_deterministic(seed):
    """Seed the PyTorch (CPU and all CUDA devices), NumPy and ``random`` RNGs.

    Only the generators are seeded. cuDNN and other backends are NOT switched
    into deterministic-kernel mode here.
    """
    seeders = (
        torch.manual_seed,
        np.random.seed,
        random.seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    torch.cuda.manual_seed_all(seed)

torch.cuda.empty_cache()
# Fall back to CPU when no GPU is present.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Side length (in pixels) that every image is resized to.
size = 512
# Base pipelines: resize to size x size, then convert the PIL image to a tensor.
train_transform = transforms.Compose([
        transforms.Resize((size,size)),
        transforms.ToTensor(),
])
val_transform = transforms.Compose([
        transforms.Resize((size,size)),
        transforms.ToTensor(),
])
set_deterministic(0)

# Build the VAE from vae_config and restore its pretrained weights.
vae = MultiScaleAutoencoderKL.from_config(vae_config)
# map_location="cpu": deserialize onto the CPU so a checkpoint saved from a GPU
# also loads on CPU-only machines and no transient GPU copy is made.
# load_state_dict then copies the values into vae's existing parameters.
# NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
reloaded_model = torch.load(
     "checkpoints/sep6redis_1000000/consolidated.pth",
     map_location="cpu",
)
vae.load_state_dict(reloaded_model)


def getdata(task, batch_size=128, data_dir='./dataset'):
    """Build a shuffled train loader and an ordered validation loader for ``task``.

    Args:
        task: Dataset name, one of ``"DTD"``, ``"GTSRB"`` or ``"SVHN"``.
        batch_size: Batch size used by both loaders.
        data_dir: Root directory given to the torchvision dataset. Datasets are
            created with ``download=True``, so missing data is fetched here.

    Returns:
        ``(train_loader, validation_loader, num_classes)``. Training reads the
        ``"train"`` split with random flips plus resize; validation reads the
        ``"test"`` split with the module-level ``val_transform``.

    Raises:
        ValueError: If ``task`` is not one of the supported dataset names.
    """
    # Resolve the dataset first so an unknown task fails fast, before any
    # seeding or downloading happens.
    if task == "DTD":
        dataset_cls, num_classes = datasets.DTD, 47
    elif task == "GTSRB":
        dataset_cls, num_classes = datasets.GTSRB, 43
    elif task == "SVHN":
        dataset_cls, num_classes = datasets.SVHN, 10
    else:
        raise ValueError(
            f"Unknown task {task!r}; expected 'DTD', 'GTSRB' or 'SVHN'"
        )

    set_deterministic(0)

    # Training-time augmentation: random horizontal/vertical flips, then resize.
    # It is passed at construction so the dataset's transform is set in one place.
    # NOTE(review): flips can turn one class into another for some datasets
    # (e.g. left/right signs in GTSRB, digits in SVHN) — confirm this is intended.
    train_augment = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.Resize((size, size)),
        transforms.ToTensor(),
    ])
    train_dataset = dataset_cls(
        root=data_dir, split="train",
        download=True, transform=train_augment,
    )
    validation_dataset = dataset_cls(
        root=data_dir, split="test",
        download=True, transform=val_transform,
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )
    validation_loader = torch.utils.data.DataLoader(
        validation_dataset, batch_size=batch_size,
    )
    return train_loader, validation_loader, num_classes



class Model(nn.Module):
    """Linear classification head on flattened ``vae.encode_image`` features.

    The encoder is cast to bfloat16 and moved to ``device`` here. Freezing its
    parameters is left to the caller.
    """

    def __init__(self,vae, D, num_classes,device,dropout_rate):
        """
        Args:
            vae: Encoder module exposing ``encode_image(x)``; stored as
                ``custom_embedding``.
            D: Number of features per sample after flattening the encoder output.
            num_classes: Number of output logits.
            device: Device the encoder is moved to (e.g. ``'cuda'`` or ``'cpu'``).
            dropout_rate: Dropout probability applied to the flattened features
                before the linear layer (active only in ``train()`` mode).
        """
        super(Model, self).__init__()
        # Honour the requested device instead of assuming a GPU is present.
        self.custom_embedding = vae.bfloat16().to(device)
        # Dropout has no parameters, so the state_dict layout is unchanged.
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(D,num_classes)

        # Weights ~ U(-1/sqrt(D), 1/sqrt(D)); biases start at zero.
        bound = 1.0 / np.sqrt(D)
        init.uniform_(self.linear.weight, -bound, bound)
        init.constant_(self.linear.bias, 0)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for input ``x``."""
        # Flatten each sample's encoder output to D features; the product of the
        # non-batch dims must equal D for the linear layer to accept it.
        features = self.custom_embedding.encode_image(x).reshape(x.shape[0],-1)
        features = self.dropout(features)
        return self.linear(features)

def calculate_accuracy(outputs,labels):
    """Return the fraction of samples whose arg-max logit equals the label.

    ``outputs`` is indexed as ``(batch, classes)``; the prediction is the
    index of the largest entry along dim 1 (first one on ties). The result is
    a Python float in ``[0, 1]``.
    """
    predicted = outputs.detach().argmax(dim=1)
    hits = int((predicted == labels).sum())
    return hits / labels.size(0)


import os

num_epochs = 20


def _train_one_epoch(model, loader, criterion, optimizer, epoch, num_epochs):
    """Run one optimisation pass over ``loader``; return ``(mean_loss, accuracy)``.

    Both metrics are averaged per sample rather than per batch, so a smaller
    final batch is not over-weighted. Returns NaNs for an empty loader.
    """
    # NOTE(review): model.train() also puts custom_embedding in train mode; if
    # the VAE has layers with train-time behaviour, consider keeping it in eval.
    model.train()
    loss_sum = 0.0
    correct_sum = 0.0
    seen = 0
    with tqdm(loader, unit="batch") as tepoch:
        tepoch.set_description('Epoch '+str(epoch+1)+'/'+str(num_epochs))
        for images, labels in tepoch:
            # x*2-1 maps the [0, 1] range to [-1, 1]; the model runs in bfloat16.
            images = images.to(device).bfloat16()*2-1
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            batch = labels.size(0)
            loss_sum += loss.item() * batch
            correct_sum += calculate_accuracy(outputs, labels) * batch
            seen += batch
    if not seen:
        return float("nan"), float("nan")
    return loss_sum / seen, correct_sum / seen


def _evaluate(model, loader, criterion):
    """Evaluate without gradients; return per-sample ``(mean_loss, accuracy)``."""
    model.eval()
    loss_sum = 0.0
    correct_sum = 0.0
    seen = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device).bfloat16()*2-1
            labels = labels.to(device)
            outputs = model(images)

            batch = labels.size(0)
            loss_sum += criterion(outputs, labels).item() * batch
            correct_sum += calculate_accuracy(outputs, labels) * batch
            seen += batch
    if not seen:
        return float("nan"), float("nan")
    return loss_sum / seen, correct_sum / seen


for task in ["DTD","GTSRB","SVHN"]:
    for wd in [0.1]:
        for drop in [0.1]:
            set_deterministic(0)

            train_loader, validation_loader, num_classes = getdata(task)
            criterion = nn.CrossEntropyLoss()

            # Flattened encoder-feature size expected by the linear head.
            # NOTE(review): assumes encode_image yields 16x32x32 values per
            # 512px image — confirm against vae_config before changing size.
            D = 16*32*32
            model = Model(vae, D, num_classes, device, drop)
            # Linear probe: the VAE encoder stays frozen.
            for name, param in model.named_parameters():
                if "custom_embedding" in name:
                    param.requires_grad = False
            # Cast and move first so the optimizer is built over the final
            # bfloat16 parameters on the target device.
            model = model.bfloat16().to(device)
            trainable = [p for p in model.parameters() if p.requires_grad]
            optimizer = torch.optim.AdamW(trainable, lr=0.0001, betas=(0.9,0.95), weight_decay=wd)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0)

            torch.cuda.empty_cache()

            train_loss_l = []
            train_accuracy_l = []
            val_loss_l = []
            val_accuracy_l = []

            optimizer.zero_grad()
            for epoch in range(num_epochs):
                train_loss, train_accuracy = _train_one_epoch(
                    model, train_loader, criterion, optimizer, epoch, num_epochs
                )
                val_loss, val_accuracy = _evaluate(model, validation_loader, criterion)

                print(task, wd,drop,'Train loss: ', train_loss, ' - Train accuracy: ', train_accuracy,' - Val loss: ', val_loss, ' - Val accuracy: ', val_accuracy)

                train_accuracy_l.append(train_accuracy)
                train_loss_l.append(train_loss)
                val_accuracy_l.append(val_accuracy)
                val_loss_l.append(val_loss)

                scheduler.step()

            # Create the output directory so a missing log/ cannot crash the
            # run after all epochs have already been trained.
            os.makedirs("log", exist_ok=True)
            metrics = {
                "train_loss": train_loss_l,
                "val_loss": val_loss_l,
                "train_acc": train_accuracy_l,
                "val_acc": val_accuracy_l,
            }
            for suffix, values in metrics.items():
                with open(f'log/{task}_{suffix}.json', 'w') as outfile:
                    json.dump(values, outfile)
